{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Transfer Learning in Keras\n",
    "\n",
    "In this notebook, we'll cover how to load a pre-trained model (in this case, VGG19) and fine-tune it for a new task: detecting hot dogs."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Load dependencies"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using TensorFlow backend.\n"
     ]
    }
   ],
   "source": [
    "from keras.applications.vgg19 import VGG19\n",
    "from keras.models import Sequential\n",
    "from keras.layers import Dense, Dropout, Flatten\n",
    "from keras.preprocessing.image import ImageDataGenerator"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Load the pre-trained VGG19 model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Load only the convolutional base of VGG19 (include_top=False) with its\n",
    "# ImageNet weights, sized for 224x224 RGB inputs and with no pooling\n",
    "# applied after the final block:\n",
    "vgg19 = VGG19(\n",
    "    weights='imagenet',\n",
    "    include_top=False,\n",
    "    input_shape=(224, 224, 3),\n",
    "    pooling=None,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Freeze all the layers in the base VGG19 model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Mark every layer of the VGG19 base as non-trainable so that its\n",
    "# pre-trained weights are left untouched during training:\n",
    "for base_layer in vgg19.layers:\n",
    "    base_layer.trainable = False"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Add custom classification layers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Build the full network in one step: the VGG19 base followed by the\n",
    "# custom classification head (flatten -> dropout -> two-unit softmax).\n",
    "model = Sequential([\n",
    "    vgg19,\n",
    "    Flatten(name='flattened'),\n",
    "    Dropout(0.5, name='dropout'),\n",
    "    Dense(2, activation='softmax', name='predictions'),\n",
    "])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Compile the model for training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Categorical cross-entropy pairs with the two-unit softmax output;\n",
    "# accuracy is tracked alongside the loss.\n",
    "model.compile(\n",
    "    loss='categorical_crossentropy',\n",
    "    optimizer='adam',\n",
    "    metrics=['accuracy'],\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Prepare the data for training"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The dataset is available for download [here](https://www.kaggle.com/dansbecker/hot-dog-not-hot-dog/home)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## N.B.: As of Nov. 1st, 2019, the above download link no longer works. \n",
    "## Instead, uncomment the two lines below to download the data: \n",
    "# ! wget -c https://www.dropbox.com/s/86r9z1kb42422up/hot-dog-not-hot-dog.tar.gz\n",
    "# ! tar -xzf hot-dog-not-hot-dog.tar.gz"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Settings common to both image generators: scale pixel values by 1/255\n",
    "# and treat images as channels-last arrays.\n",
    "shared_kwargs = dict(rescale=1.0/255, data_format='channels_last')\n",
    "\n",
    "# Training images are additionally augmented with random rotations of up\n",
    "# to 30 degrees and random horizontal flips; 'reflect' fills in the pixels\n",
    "# exposed by a rotation:\n",
    "train_datagen = ImageDataGenerator(\n",
    "    rotation_range=30,\n",
    "    horizontal_flip=True,\n",
    "    fill_mode='reflect',\n",
    "    **shared_kwargs)\n",
    "\n",
    "# Validation images are only rescaled, never augmented:\n",
    "valid_datagen = ImageDataGenerator(**shared_kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Define the batch size:\n",
    "batch_size=32"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Found 498 images belonging to 2 classes.\n",
      "Found 500 images belonging to 2 classes.\n"
     ]
    }
   ],
   "source": [
    "# Options shared by the training and validation directory iterators:\n",
    "flow_kwargs = dict(\n",
    "    target_size=(224, 224),\n",
    "    classes=['hot_dog', 'not_hot_dog'],\n",
    "    class_mode='categorical',\n",
    "    batch_size=batch_size,\n",
    "    shuffle=True,\n",
    "    seed=42,\n",
    ")\n",
    "\n",
    "# Training batches are read from the 'train' folder:\n",
    "train_generator = train_datagen.flow_from_directory(\n",
    "    directory='./hot-dog-not-hot-dog/train', **flow_kwargs)\n",
    "\n",
    "# Validation batches are read from the 'test' folder:\n",
    "valid_generator = valid_datagen.flow_from_directory(\n",
    "    directory='./hot-dog-not-hot-dog/test', **flow_kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/16\n",
      "15/15 [==============================] - 43s - loss: 0.9661 - acc: 0.5729 - val_loss: 0.5618 - val_acc: 0.7146\n",
      "Epoch 2/16\n",
      "15/15 [==============================] - 5s - loss: 0.7775 - acc: 0.6486 - val_loss: 0.5116 - val_acc: 0.7393\n",
      "Epoch 3/16\n",
      "15/15 [==============================] - 4s - loss: 0.5765 - acc: 0.7410 - val_loss: 0.6632 - val_acc: 0.6966\n",
      "Epoch 4/16\n",
      "15/15 [==============================] - 4s - loss: 0.4318 - acc: 0.8123 - val_loss: 0.4298 - val_acc: 0.7949\n",
      "Epoch 5/16\n",
      "15/15 [==============================] - 4s - loss: 0.3937 - acc: 0.8076 - val_loss: 0.4536 - val_acc: 0.7885\n",
      "Epoch 6/16\n",
      "15/15 [==============================] - 4s - loss: 0.3500 - acc: 0.8524 - val_loss: 0.4060 - val_acc: 0.8120\n",
      "Epoch 7/16\n",
      "15/15 [==============================] - 4s - loss: 0.3216 - acc: 0.8514 - val_loss: 0.5474 - val_acc: 0.7438\n",
      "Epoch 8/16\n",
      "15/15 [==============================] - 4s - loss: 0.2992 - acc: 0.8571 - val_loss: 0.4914 - val_acc: 0.7714\n",
      "Epoch 9/16\n",
      "15/15 [==============================] - 4s - loss: 0.2927 - acc: 0.8806 - val_loss: 0.5161 - val_acc: 0.7607\n",
      "Epoch 10/16\n",
      "15/15 [==============================] - 4s - loss: 0.2808 - acc: 0.8941 - val_loss: 0.4618 - val_acc: 0.8056\n",
      "Epoch 11/16\n",
      "15/15 [==============================] - 4s - loss: 0.3981 - acc: 0.8306 - val_loss: 0.6498 - val_acc: 0.7286\n",
      "Epoch 12/16\n",
      "15/15 [==============================] - 4s - loss: 0.3303 - acc: 0.8619 - val_loss: 0.6572 - val_acc: 0.7521\n",
      "Epoch 13/16\n",
      "15/15 [==============================] - 4s - loss: 0.2846 - acc: 0.8639 - val_loss: 0.4795 - val_acc: 0.8104\n",
      "Epoch 14/16\n",
      "15/15 [==============================] - 4s - loss: 0.2657 - acc: 0.8774 - val_loss: 0.7141 - val_acc: 0.7115\n",
      "Epoch 15/16\n",
      "15/15 [==============================] - 4s - loss: 0.2646 - acc: 0.8717 - val_loss: 0.6161 - val_acc: 0.7479\n",
      "Epoch 16/16\n",
      "15/15 [==============================] - 4s - loss: 0.3248 - acc: 0.8472 - val_loss: 1.2689 - val_acc: 0.6261\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<keras.callbacks.History at 0x7fed7bd6e940>"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Take the step counts from the generators instead of hard-coding 15:\n",
    "# with 498 training and 500 validation images at a batch size of 32, a\n",
    "# full pass needs 16 batches, so 15 steps skipped part of the training\n",
    "# data each epoch and scored a different partial subset of the validation\n",
    "# data every time. len(generator) is the number of batches in one full pass.\n",
    "model.fit_generator(train_generator, steps_per_epoch=len(train_generator),\n",
    "                    epochs=16, validation_data=valid_generator,\n",
    "                    validation_steps=len(valid_generator))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
